{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1fb76974-93ea-4b9c-81b1-55f826e7a361",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################################################################################\n",
    "# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
    "########################################################################################################\n",
    "\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
    "import types, torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "# TorchScript aliases: the model class subclasses MyModule and decorates its compiled methods with MyFunction.\n",
    "MyModule, MyFunction = torch.jit.ScriptModule, torch.jit.script_method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e6c9297-472f-4fd8-ad19-d8072b5040f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: RWKV-5 is also known as \"Eagle\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1059eca-db4f-4c0b-ae3e-37af49ec7fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c8d8009-7ee7-4419-aacb-cdc45f287010",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RWKV_TOKENIZER():\n",
    "    \"\"\"Greedy longest-match byte tokenizer driven by a vocabulary file.\n",
    "\n",
    "    Each vocabulary line is read as ``<index> <str-or-bytes literal> <byte length>``.\n",
    "    Tables indexed by a token's first two bytes are precomputed in ``__init__`` so\n",
    "    that ``encodeBytes`` only scans tokens that can start at the current position.\n",
    "    \"\"\"\n",
    "    table: list[list[list[bytes]]]  # table[b0][b1]: tokens of length >= 2 whose first two bytes are b0, b1\n",
    "    good: list[set[int]]            # good[b0]: every b1 for which table[b0][b1] is non-empty\n",
    "    wlen: list[int]                 # wlen[b0]: byte length of the longest token starting with b0\n",
    "\n",
    "    def __init__(self, file_name):\n",
    "        \"\"\"Read the vocabulary at ``file_name`` and build the lookup tables.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: an entry is neither str nor bytes, or its declared\n",
    "                byte length does not match the decoded token.\n",
    "        \"\"\"\n",
    "        import ast  # local import: only needed while parsing the vocabulary file\n",
    "\n",
    "        self.idx2token = {}\n",
    "        vocab = []  # tokens in file order; the file must be already sorted\n",
    "        # Context manager so the file handle is closed as soon as the lines are read.\n",
    "        with open(file_name, \"r\", encoding=\"utf-8\") as vocab_file:\n",
    "            lines = vocab_file.readlines()\n",
    "        for l in lines:\n",
    "            first_sp = l.index(' ')\n",
    "            last_sp = l.rindex(' ')\n",
    "            idx = int(l[:first_sp])\n",
    "            # literal_eval instead of eval: the field is a plain str/bytes literal, and\n",
    "            # literal_eval cannot execute arbitrary code from a tampered vocabulary file.\n",
    "            x = ast.literal_eval(l[first_sp + 1:last_sp])\n",
    "            x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
    "            assert isinstance(x, bytes)\n",
    "            assert len(x) == int(l[last_sp:])\n",
    "            vocab.append(x)\n",
    "            self.idx2token[idx] = x\n",
    "\n",
    "        self.token2idx = {tok: int(idx) for idx, tok in self.idx2token.items()}\n",
    "\n",
    "        # precompute some tables for fast matching\n",
    "        self.table = [[[] for j in range(256)] for i in range(256)]\n",
    "        self.good = [set() for i in range(256)]\n",
    "        self.wlen = [0 for i in range(256)]\n",
    "\n",
    "        for s in reversed(vocab): # reverse order - match longer tokens first\n",
    "            if len(s) >= 2:\n",
    "                s0 = int(s[0])\n",
    "                s1 = int(s[1])\n",
    "                self.table[s0][s1].append(s)\n",
    "                self.wlen[s0] = max(self.wlen[s0], len(s))\n",
    "                self.good[s0].add(s1)\n",
    "\n",
    "    def encodeBytes(self, src: bytes) -> list[int]:\n",
    "        \"\"\"Tokenize ``src`` greedily and return the list of token indices.\n",
    "\n",
    "        At each position the first table entry that prefixes the remaining bytes is\n",
    "        taken; when none applies, the single byte at that position is used.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: the chosen byte string is not in the vocabulary.\n",
    "        \"\"\"\n",
    "        src_len: int = len(src)\n",
    "        tokens: list[int] = []\n",
    "        i: int = 0\n",
    "        while i < src_len:\n",
    "            s: bytes = src[i : i + 1]  # fallback: the single byte at position i\n",
    "\n",
    "            if i < src_len - 1:\n",
    "                s1: int = int(src[i + 1])\n",
    "                s0: int = int(src[i])\n",
    "                if s1 in self.good[s0]:\n",
    "                    sss: bytes = src[i : i + self.wlen[s0]]\n",
    "                    # First candidate that prefixes the window; keep the single byte when none does.\n",
    "                    s = next(filter(sss.startswith, self.table[s0][s1]), s)\n",
    "            tokens.append(self.token2idx[s])\n",
    "            i += len(s)\n",
    "\n",
    "        return tokens\n",
    "\n",
    "    def decodeBytes(self, tokens):\n",
    "        \"\"\"Concatenate the byte strings of ``tokens``; raises KeyError for an unknown index.\"\"\"\n",
    "        return b''.join(self.idx2token[i] for i in tokens)\n",
    "\n",
    "    def encode(self, src: str):\n",
    "        \"\"\"Tokenize the UTF-8 encoding of ``src``.\"\"\"\n",
    "        return self.encodeBytes(src.encode(\"utf-8\"))\n",
    "\n",
    "    def decode(self, tokens):\n",
    "        \"\"\"Decode ``tokens`` to text; raises UnicodeDecodeError if the bytes are not valid UTF-8.\"\"\"\n",
    "        return self.decodeBytes(tokens).decode('utf-8')\n",
    "\n",
    "    def printTokens(self, tokens):\n",
    "        \"\"\"Print each token's repr immediately followed by its index, space separated, on one line.\"\"\"\n",
    "        for i in tokens:\n",
    "            s = self.idx2token[i]\n",
    "            try:\n",
    "                s = s.decode('utf-8')\n",
    "            except UnicodeDecodeError:\n",
    "                pass  # not valid UTF-8 on its own: show the raw bytes repr instead\n",
    "            print(f'{repr(s)}{i}', end=' ')\n",
    "        print()\n",
    "\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "63a4e8ba-a291-4fdc-aef1-ebfca21840d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_logits(out, temperature=1.0, top_p=0.8):\n",
    "    \"\"\"Sample one index from logits using top-p (nucleus) filtering and temperature.\n",
    "\n",
    "    Args:\n",
    "        out: torch tensor of logits; expected to be 1-D and on the CPU, since the\n",
    "            softmax result is converted with ``.numpy()`` and indexed by ``len(probs)``.\n",
    "        temperature: after filtering, probabilities are raised to ``1 / temperature``;\n",
    "            1.0 leaves them unchanged.\n",
    "        top_p: probabilities below the value at which the descending cumulative sum\n",
    "            first exceeds ``top_p`` are zeroed before sampling.\n",
    "\n",
    "    Returns:\n",
    "        The index drawn by ``np.random.choice`` from the renormalised distribution.\n",
    "    \"\"\"\n",
    "    probs = F.softmax(out, dim=-1).numpy()\n",
    "    sorted_probs = np.sort(probs)[::-1]\n",
    "    cumulative_probs = np.cumsum(sorted_probs)\n",
    "    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
    "    probs[probs < cutoff] = 0\n",
    "    if temperature != 1.0:\n",
    "        # probs is a NumPy array here; ndarray has no .pow() method (that is the torch.Tensor API).\n",
    "        probs = np.power(probs, 1.0 / temperature)\n",
    "    probs = probs / np.sum(probs)\n",
    "    out = np.random.choice(a=len(probs), p=probs)\n",
    "    return out\n",
    "\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "raw",
   "id": "cb8c7d5e-08cb-4780-b6d9-ab8bad1417d4",
   "metadata": {},
   "source": [
    "可以从这个链接下载模型：\n",
    "https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files\n",
    "https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "94d7d6db-e89e-4209-ae72-6625ba85ef5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tokenizer from the vocabulary file expected next to this notebook.\n",
    "tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
    "\n",
    "# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS\n",
    "\n",
    "# Model settings consumed by RWKV_RNN.\n",
    "# NOTE(review): MODEL_NAME is a machine-specific absolute path; adjust it to where the checkpoint lives.\n",
    "args = types.SimpleNamespace()\n",
    "args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096' # no .pth suffix here (RWKV_RNN.__init__ appends it)\n",
    "args.n_layer = 24\n",
    "args.n_embd = 1024\n",
    "args.vocab_size = 65536"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "c8dcf39a-7838-454b-85fc-ec9bd75fa243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# String-valued layer count and embedding width (alternative values kept commented out).\n",
    "# NOTE(review): no code visible in this notebook reads N_LAYER / N_EMBD; the model uses\n",
    "# args.n_layer / args.n_embd instead - confirm whether these are still needed.\n",
    "# N_LAYER=\"12\"\n",
    "# N_EMBD=\"768\"\n",
    "N_LAYER=\"24\"\n",
    "N_EMBD=\"1024\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "74d7c96a-6fbc-401c-8078-fefb1a6ec5c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt fed to the model before sampling starts (alternatives kept commented out).\n",
    "# context = \"\\nElon Musk has\"\n",
    "# context = \"\\n我们发现\"\n",
    "context = \"Q:Do you know datawhalechina?\\nA:\"\n",
    "NUM_TRIALS = 3  # number of continuations sampled from the same prompt\n",
    "LENGTH_PER_TRIAL = 4096  # tokens sampled per trial; lower it (e.g. 100) for a quick run\n",
    "TEMPERATURE = 1.0  # passed to sample_logits as temperature\n",
    "TOP_P = 0.7  # passed to sample_logits as top_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "bd093a96-fdc5-460d-b39f-fe3735795b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RWKV_RNN(MyModule):\n",
    "    \"\"\"Token-by-token (RNN mode) CPU inference over the weights of an RWKV-5 checkpoint.\n",
    "\n",
    "    The recurrent state is a 2-D tensor with ``n_layer * (2 + head_size)`` rows and\n",
    "    ``n_embd`` columns. For layer ``i``, with ``base = (2 + head_size) * i``:\n",
    "      * row ``base``      - previous input seen by ``channel_mixing``\n",
    "      * row ``base + 1``  - previous input seen by ``time_mixing``\n",
    "      * the remaining ``head_size`` rows - viewed as (n_head, head_size, head_size)\n",
    "        and updated by ``time_mixing``\n",
    "    \"\"\"\n",
    "    def __init__(self, args):\n",
    "        \"\"\"Load ``args.MODEL_NAME + '.pth'`` onto the CPU and expose the weights as ``self.w``.\n",
    "\n",
    "        ``args`` must provide ``MODEL_NAME`` here, and ``n_layer`` / ``n_embd`` for\n",
    "        ``forward`` and ``layer_norm``.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.args = args\n",
    "        self.eval() # set torch to inference mode\n",
    "\n",
    "        # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
    "        for k in w.keys():\n",
    "            w[k] = w[k].float() # convert to f32 type\n",
    "            if      '.time_' in k: w[k] = w[k].squeeze()\n",
    "            # Stored decay is mapped through exp(-exp(.)); a trailing axis is added for broadcasting.\n",
    "            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)\n",
    "            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
    "\n",
    "        self.n_head = w['blocks.0.att.time_decay'].shape[0]\n",
    "        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
    "\n",
    "        self.w = types.SimpleNamespace() # set self.w from w\n",
    "        self.w.blocks = {}\n",
    "        for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
    "            parts = k.split('.')\n",
    "            last = parts.pop()\n",
    "            here = self.w\n",
    "            for p in parts:\n",
    "                if p.isdigit():\n",
    "                    # Numeric components index a dict; assumes they only occur directly under 'blocks'.\n",
    "                    p = int(p)\n",
    "                    if p not in here: here[p] = types.SimpleNamespace()\n",
    "                    here = here[p]\n",
    "                else:\n",
    "                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
    "                    here = getattr(here, p)\n",
    "            setattr(here, last, w[k])\n",
    "\n",
    "    def layer_norm(self, x, w):\n",
    "        \"\"\"Layer-normalise ``x`` over ``n_embd`` features using ``w.weight`` and ``w.bias``.\"\"\"\n",
    "        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
    "\n",
    "    # Channel-mixing (feed-forward) step for layer i.\n",
    "    # Mixes x with the previous input stored in the state, stores x back into the\n",
    "    # state in place, and returns sigmoid(rw @ xr) * (vw @ relu(kw @ xk)^2).\n",
    "    @MyFunction\n",
    "    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "        i0 = (2+self.head_size)*i+0\n",
    "        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
    "        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
    "        state[i0] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
    "        return r * (vw @ k)\n",
    "\n",
    "    # Time-mixing step for layer i.\n",
    "    # Mixes x with the previous input stored in the state, updates that row and the\n",
    "    # layer's (H, S, S) block of the state in place, and returns the output projection.\n",
    "    @MyFunction\n",
    "    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
    "        H = self.n_head\n",
    "        S = self.head_size\n",
    "\n",
    "        i1 = (2+S)*i+1\n",
    "        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
    "        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
    "        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
    "        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
    "        state[i1] = x\n",
    "\n",
    "        r = (rw @ xr).view(H, 1, S)\n",
    "        k = (kw @ xk).view(H, S, 1)\n",
    "        v = (vw @ xv).view(H, 1, S)\n",
    "        g = F.silu(gw @ xg)\n",
    "\n",
    "        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
    "\n",
    "        a = k @ v # (H, S, 1) @ (H, 1, S): per-head outer product, shape (H, S, S)\n",
    "        x = r @ (time_first * a + s) # read-out uses the current token's contribution scaled by time_first\n",
    "        s = a + time_decay * s # stored block: old value scaled by time_decay plus the new outer product\n",
    "\n",
    "        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
    "        x = x.flatten()\n",
    "\n",
    "        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
    "        return ow @ x\n",
    "\n",
    "    def forward(self, token, state):\n",
    "        \"\"\"Run one token through every layer.\n",
    "\n",
    "        Args:\n",
    "            token: index used to look up a row of the embedding weight.\n",
    "            state: state tensor returned by a previous call, or None to start from zeros.\n",
    "\n",
    "        Returns:\n",
    "            ``(logits, state)``: the head projection as a float tensor, and the state\n",
    "            tensor, which is the same object that was passed in (updated in place)\n",
    "            unless it was None.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            if state is None:\n",
    "                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
    "\n",
    "            x = self.w.emb.weight[token]\n",
    "            x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
    "            for i in range(self.args.n_layer):\n",
    "                att = self.w.blocks[i].att\n",
    "                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
    "                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,\n",
    "                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
    "                    att.ln_x.weight, att.ln_x.bias)\n",
    "                ffn = self.w.blocks[i].ffn\n",
    "                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,\n",
    "                    ffn.time_mix_k, ffn.time_mix_r,\n",
    "                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
    "\n",
    "            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
    "            return x.float(), state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "a330cd34-7ed0-4a6c-92a3-19797d34ee77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt used for generation (same value as assigned in the generation-settings cell).\n",
    "context = \"Q:Do you know datawhalechina?\\nA:\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "ad824161-413d-460c-9ffe-9dbfb739f86b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the checkpoint path (without the .pth suffix) that will be loaded.\n",
    "args.MODEL_NAME"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "f0e2f841-4cda-48d4-b055-7adf00f2fe73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24, 1024)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the layer count and embedding width currently configured on args.\n",
    "args.n_layer,args.n_embd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "aba8a4d4-9a77-4191-a7ef-d5e6100ca3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.n_layer = 24\n",
    "# args.n_embd = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "dd44f7bc-e8d6-4242-beb5-89a866990751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.n_layer = 12\n",
    "# args.n_embd = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "2a96d9dc-8b5e-40cc-bb36-24c9bdeac29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.MODEL_NAME='../models/rwkv-5-world-1b5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "b7d07606-31b4-4c21-9f89-554d89c2c866",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...\n",
      "\n",
      "Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
     ]
    }
   ],
   "source": [
    "# Build the model; the constructor reads the checkpoint from disk, so announce it first.\n",
    "print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
    "model = RWKV_RNN(args)\n",
    "\n",
    "# No recurrent state yet: model.forward allocates a zero state when it is given None.\n",
    "init_state = None\n",
    "print('\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "ce42cfad-0274-4d5d-950d-fb89ff11ed2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear any previously computed state so the prompt is processed from a fresh (zero) state.\n",
    "init_state = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "3e02a81c-1447-4936-a241-4d00ecf8e862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the number of tokens sampled per trial for the run below.\n",
    "LENGTH_PER_TRIAL=1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "4a00ea05-d6fd-4052-b13a-8107fb268420",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "--[ Trial 0 ]----------------- Q:Do you know datawhalechina?\n",
      "A:If you have your money, go to the blue whale website. It has information on whales and its feeding habits. The info is up-to-date and well-researched.\n",
      "https://www.who.int/news-room/fact-sheets/detail/blue-whale\n",
      "It's a nice fish. It's great for eating.\n",
      "\n",
      "--[ Trial 1 ]----------------- Q:Do you know datawhalechina?\n",
      "A:A very old and very great one.\n",
      "http://www.cnn.com/2008/WORLD/asia/10/17/china.beach.tourist.casino.cnn/index.html\n",
      "\n",
      "--[ Trial 2 ]----------------- Q:Do you know datawhalechina?\n",
      "A:The datawhale china is a chinese based data analytics and decision making company, they work with many big companies in China. They work with large numbers of large companies in china. They can provide companies with data, analytics and knowledge about companies.\n",
      "Q:How do you find datawhale china?\n",
      "A:We use a variety of sources to find companies in china. We search for companies based on a variety of criteria. We look for companies with a specific industry or product. We also use a variety of data sources to find companies in china.\n",
      "Q:What kind of data do you use to find companies in china?\n",
      "A:We use a variety of data sources to find companies in china. We look for companies with a specific industry or product. We also use a variety of sources to find companies in china.\n",
      "Q:How do you know if a company is in china?\n",
      "A:We use a variety of sources to find companies in china. We look for companies based on a variety of criteria. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the advantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What is the purpose of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the advantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the advantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the advantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "Q:What are some of the disadvantages of using datawhale china?\n",
      "A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Feed the prompt one token at a time to obtain the logits and recurrent state that follow it.\n",
    "for token in tokenizer.encode(context):\n",
    "    init_out, init_state = model.forward(token, init_state)\n",
    "\n",
    "for TRIAL in range(NUM_TRIALS):\n",
    "    print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
    "    all_tokens = []\n",
    "    out_last = 0  # index of the first sampled token that has not been printed yet\n",
    "    # Each trial starts from copies of the post-prompt logits/state, because forward updates the state in place.\n",
    "    out, state = init_out.clone(), init_state.clone()\n",
    "    for i in range(LENGTH_PER_TRIAL):\n",
    "        token = sample_logits(out, TEMPERATURE, TOP_P)\n",
    "        all_tokens += [token]\n",
    "        try:\n",
    "            tmp = tokenizer.decode(all_tokens[out_last:])\n",
    "        except (UnicodeDecodeError, KeyError):\n",
    "            # Pending bytes are not valid UTF-8 yet, or a sampled id has no vocabulary entry:\n",
    "            # keep the tokens buffered. Catching only these lets KeyboardInterrupt stop the cell.\n",
    "            tmp = None\n",
    "        if tmp is not None and '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
    "            print(tmp, end=\"\", flush=True)\n",
    "            out_last = i + 1\n",
    "        out, state = model.forward(token, state)\n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfb6dc7a-786f-427b-9aa9-b51f2771b0cc",
   "metadata": {},
   "source": [
    "显然，模型没有在 datawhale 相关的数据上训练过，所以回答并不准确。不过推理速度确实很快，而且是在 CPU 上运行，资源消耗也很小。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kewei-ai",
   "language": "python",
   "name": "kewei-ai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
